{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xLOXFOT5Q40E"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "iiQkM5ZgQ8r2"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j6331ZSsQGY3"
      },
      "source": [
        "# MNIST の分類"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i9Jcnb8bQQyd"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/quantum/tutorials/mnist\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">TensorFlow.org で表示</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/quantum/tutorials/mnist.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\"> Google Colab で実行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/quantum/tutorials/mnist.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub でソースを表示{</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/quantum/tutorials/mnist.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード/a0}</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "udLObUVeGfTs"
      },
      "source": [
        "このチュートリアルでは、簡略化された MNIST バージョンを分類する、<a href=\"https://arxiv.org/pdf/1802.06002.pdf\" class=\"external\">Farhi et al</a> で使用されたアプローチに似た量子ニューラルネットワーク（QNN）を構築し、古典的なデータ問題における量子ニューラルネットワークのパフォーマンスを従来のニューラルネットワークと比較します。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X35qHdh5Gzqg"
      },
      "source": [
        "## セットアップ"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TorxE5tnkvb2"
      },
      "outputs": [],
      "source": [
        "!pip install tensorflow==2.3.1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FxkQA6oblNqI"
      },
      "source": [
        "TensorFlow Quantum をインストールします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "saFHsRDpkvkH"
      },
      "outputs": [],
      "source": [
        "!pip install tensorflow-quantum"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hdgMMZEBGqyl"
      },
      "source": [
        "次に、TensorFlow とモジュールの依存関係をインポートします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "enZ300Bflq80"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_quantum as tfq\n",
        "\n",
        "import cirq\n",
        "import sympy\n",
        "import numpy as np\n",
        "import seaborn as sns\n",
        "import collections\n",
        "\n",
        "# visualization tools\n",
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt\n",
        "from cirq.contrib.svg import SVGCircuit"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b08Mmbs8lr81"
      },
      "source": [
        "## 1. データを読み込む\n",
        "\n",
        "このチュートリアルでは、<a href=\"https://arxiv.org/pdf/1802.06002.pdf\" class=\"external\">Farhi et al.</a> に従って、数字の 3 と 6 を区別する二項分類器を構築します。このセクションでは、次を行うデータ処理を説明します。\n",
        "\n",
        "- Keras から生データを読み込みます。\n",
        "- データセットを 3 と 6 に絞り込みます。\n",
        "- 画像が量子コンピュータに適合するように、画像を縮小します。\n",
        "- 矛盾するサンプルを取り除きます。\n",
        "- バイナリ画像を Cirq 回路に変換します。\n",
        "- Cirq 回路を TensorFlow Quantum 回路に変換します。 "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pDUdGxn-ojgy"
      },
      "source": [
        "### 1.1 生データを読み込む"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xZyGXlaKojgz"
      },
      "source": [
        "Keras で配布された MNIST データセットを読み込みます。 "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d9OSExvCojg0"
      },
      "outputs": [],
      "source": [
        "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
        "\n",
        "# Rescale the images from [0,255] to the [0.0,1.0] range.\n",
        "x_train, x_test = x_train[..., np.newaxis]/255.0, x_test[..., np.newaxis]/255.0\n",
        "\n",
        "print(\"Number of original training examples:\", len(x_train))\n",
        "print(\"Number of original test examples:\", len(x_test))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fZpbygdGojg3"
      },
      "source": [
        "3 と 6 の数字のみを保持してほかのクラスを取り除くように、データセットを絞り込みます。同時に、`3` を `True`、6 を `False` というように、ラベル `y` をブール型に変換します。 "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hOw68cCZojg4"
      },
      "outputs": [],
      "source": [
        "def filter_36(x, y):\n",
        "    keep = (y == 3) | (y == 6)\n",
        "    x, y = x[keep], y[keep]\n",
        "    y = y == 3\n",
        "    return x,y"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p-XEU8egGL6q"
      },
      "outputs": [],
      "source": [
        "x_train, y_train = filter_36(x_train, y_train)\n",
        "x_test, y_test = filter_36(x_test, y_test)\n",
        "\n",
        "print(\"Number of filtered training examples:\", len(x_train))\n",
        "print(\"Number of filtered test examples:\", len(x_test))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3wyiaP0Xojg_"
      },
      "source": [
        "最初のサンプルを表示します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j5STP7MbojhA"
      },
      "outputs": [],
      "source": [
        "print(y_train[0])\n",
        "\n",
        "plt.imshow(x_train[0, :, :, 0])\n",
        "plt.colorbar()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wNS9sVPQojhC"
      },
      "source": [
        "### 1.2 画像を縮小する"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fmmtplIFGL6t"
      },
      "source": [
        "現在の量子コンピュータでは、画像サイズ 28x28 は大きすぎるため、4x4 に縮小します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lbhUdBFWojhE"
      },
      "outputs": [],
      "source": [
        "x_train_small = tf.image.resize(x_train, (4,4)).numpy()\n",
        "x_test_small = tf.image.resize(x_test, (4,4)).numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pOMd7zIjGL6x"
      },
      "source": [
        "サイズ変更を行ったら、もう一度最初のサンプルを表示します。 "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YIYOtCRIGL6y"
      },
      "outputs": [],
      "source": [
        "print(y_train[0])\n",
        "\n",
        "plt.imshow(x_train_small[0,:,:,0], vmin=0, vmax=1)\n",
        "plt.colorbar()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gGeF1_qtojhK"
      },
      "source": [
        "### 1.3 矛盾するサンプルを取り除く"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7ZLkq2yeojhL"
      },
      "source": [
        "<a href=\"https://arxiv.org/pdf/1802.06002.pdf\" class=\"external\">Farhi et al.</a> の *3.3 Learning to Distinguish Digits* セクションに説明されているように、データセットから両方のクラスに属するラベルが付けられた画像を取り除きます。\n",
        "\n",
        "これは標準的な機械学習の手順ではありませんが、論文の手順に従う目的で追加している手順です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LqOPW0C7ojhL"
      },
      "outputs": [],
      "source": [
        "def remove_contradicting(xs, ys):\n",
        "    mapping = collections.defaultdict(set)\n",
        "    # Determine the set of labels for each unique image:\n",
        "    for x,y in zip(xs,ys):\n",
        "       mapping[tuple(x.flatten())].add(y)\n",
        "    \n",
        "    new_x = []\n",
        "    new_y = []\n",
        "    for x,y in zip(xs, ys):\n",
        "      labels = mapping[tuple(x.flatten())]\n",
        "      if len(labels) == 1:\n",
        "          new_x.append(x)\n",
        "          new_y.append(list(labels)[0])\n",
        "      else:\n",
        "          # Throw out images that match more than one label.\n",
        "          pass\n",
        "    \n",
        "    num_3 = sum(1 for value in mapping.values() if True in value)\n",
        "    num_6 = sum(1 for value in mapping.values() if False in value)\n",
        "    num_both = sum(1 for value in mapping.values() if len(value) == 2)\n",
        "\n",
        "    print(\"Number of unique images:\", len(mapping.values()))\n",
        "    print(\"Number of 3s: \", num_3)\n",
        "    print(\"Number of 6s: \", num_6)\n",
        "    print(\"Number of contradictory images: \", num_both)\n",
        "    print()\n",
        "    print(\"Initial number of examples: \", len(xs))\n",
        "    print(\"Remaining non-contradictory examples: \", len(new_x))\n",
        "    \n",
        "    return np.array(new_x), np.array(new_y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VMOiJfz_ojhP"
      },
      "source": [
        "取り除いた結果の数量は、レポートされている値に密に一致していませんが、これは正確な手順が指定されていないためです。\n",
        "\n",
        "また、この時点で、矛盾するサンプルをフィルタリングすることで、矛盾するトレーニングサンプルがモデルに絶対に送られないということではありません。次のデータの二項化のステップでは、さらに競合が発生します。 "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zpnsAssWojhP"
      },
      "outputs": [],
      "source": [
        "x_train_nocon, y_train_nocon = remove_contradicting(x_train_small, y_train)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SlJ5NVaPojhT"
      },
      "source": [
        "### 1.3 データを量子回路としてエンコードする\n",
        "\n",
        "<a href=\"https://arxiv.org/pdf/1802.06002.pdf\" class=\"external\">Farhi et al.</a> は、量子コンピュータを使って画像を処理するには、ピクセルの値に応じた状態で、各ピクセルをキュービットで表現するように提案しています。最初のステップでは、バイナリエンコーディングへの変換を行います。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1z8J7OyDojhV"
      },
      "outputs": [],
      "source": [
        "THRESHOLD = 0.5\n",
        "\n",
        "x_train_bin = np.array(x_train_nocon > THRESHOLD, dtype=np.float32)\n",
        "x_test_bin = np.array(x_test_small > THRESHOLD, dtype=np.float32)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SlJ5NVaPojhU"
      },
      "source": [
        "この時点で矛盾した画像を取り除く場合、193 個しか画像が残らず、これでは有効なトレーニングを行える数量とは言えません。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1z8J7OyDojhW"
      },
      "outputs": [],
      "source": [
        "_ = remove_contradicting(x_train_bin, y_train_nocon)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oLyxS9KlojhZ"
      },
      "source": [
        "しきい値を超える値を持つピクセルインデックスのキュービットは、$X$ ゲートを介して循環します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aOu_3-3ZGL61"
      },
      "outputs": [],
      "source": [
        "def convert_to_circuit(image):\n",
        "    \"\"\"Encode truncated classical image into quantum datapoint.\"\"\"\n",
        "    values = np.ndarray.flatten(image)\n",
        "    qubits = cirq.GridQubit.rect(4, 4)\n",
        "    circuit = cirq.Circuit()\n",
        "    for i, value in enumerate(values):\n",
        "        if value:\n",
        "            circuit.append(cirq.X(qubits[i]))\n",
        "    return circuit\n",
        "\n",
        "\n",
        "x_train_circ = [convert_to_circuit(x) for x in x_train_bin]\n",
        "x_test_circ = [convert_to_circuit(x) for x in x_test_bin]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zSCXqzOzojhd"
      },
      "source": [
        "これは最初のサンプルに作成された回路です（回路図には、ゼロゲートのキュービットは表示されていません）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w3POmUEUojhe"
      },
      "outputs": [],
      "source": [
        "SVGCircuit(x_train_circ[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AEQMxCcBojhg"
      },
      "source": [
        "この回路を、画像の値がしきい値を超えるインデックスを比較します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TBIsiXdtojhh"
      },
      "outputs": [],
      "source": [
        "bin_img = x_train_bin[0,:,:,0]\n",
        "indices = np.array(np.where(bin_img)).T\n",
        "indices"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mWZ24w1Oojhk"
      },
      "source": [
        "これらの `Cirq` 回路を `tfq` のテンソルに変換します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IZStEMk4ojhk"
      },
      "outputs": [],
      "source": [
        "x_train_tfcirc = tfq.convert_to_tensor(x_train_circ)\n",
        "x_test_tfcirc = tfq.convert_to_tensor(x_test_circ)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4USiqeOqGL67"
      },
      "source": [
        "## 2. 量子ニューラルネットワーク\n",
        "\n",
        "画像を分類する量子回路構造に関するガイダンスはほとんどありません。分類は読み出されるキュービットの期待に基づいて行われるため、<a href=\"https://arxiv.org/pdf/1802.06002.pdf\" class=\"external\">Farhi et al.</a> は、2 つのキュービットゲートを使用して、読み出しキュービットが必ず作用されるようにすることを提案しています。これはある意味、ピクセル全体に小さな<a href=\"https://arxiv.org/abs/1511.06464\" class=\"external\">ユニタリ RNN</a>を実行することに似ています。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "knIzawEeojho"
      },
      "source": [
        "### 2.1 モデル回路を構築する\n",
        "\n",
        "次の例では、このレイヤー化アプローチを説明しています。各レイヤーは同一ゲートの *n* 個のインスタンスを使用しており、各データキュービットは読み出しキュービットに影響を与えています。\n",
        "\n",
        "ゲートのレイヤーを回路に追加する簡単なクラスから始めましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-hjxxgU5ojho"
      },
      "outputs": [],
      "source": [
        "class CircuitLayerBuilder():\n",
        "    def __init__(self, data_qubits, readout):\n",
        "        self.data_qubits = data_qubits\n",
        "        self.readout = readout\n",
        "    \n",
        "    def add_layer(self, circuit, gate, prefix):\n",
        "        for i, qubit in enumerate(self.data_qubits):\n",
        "            symbol = sympy.Symbol(prefix + '-' + str(i))\n",
        "            circuit.append(gate(qubit, self.readout)**symbol)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Sjo5hANFojhr"
      },
      "source": [
        "サンプル回路レイヤーを構築して、どのようになるかを確認します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SzXWOpUGojhs"
      },
      "outputs": [],
      "source": [
        "demo_builder = CircuitLayerBuilder(data_qubits = cirq.GridQubit.rect(4,1),\n",
        "                                   readout=cirq.GridQubit(-1,-1))\n",
        "\n",
        "circuit = cirq.Circuit()\n",
        "demo_builder.add_layer(circuit, gate = cirq.XX, prefix='xx')\n",
        "SVGCircuit(circuit)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T-QhPE1pojhu"
      },
      "source": [
        "では、2 レイヤーモデルを構築しましょう。 データ回路サイズに一致するようにし、準備と読み出し演算を含めます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JiALbpwRGL69"
      },
      "outputs": [],
      "source": [
        "def create_quantum_model():\n",
        "    \"\"\"Create a QNN model circuit and readout operation to go along with it.\"\"\"\n",
        "    data_qubits = cirq.GridQubit.rect(4, 4)  # a 4x4 grid.\n",
        "    readout = cirq.GridQubit(-1, -1)         # a single qubit at [-1,-1]\n",
        "    circuit = cirq.Circuit()\n",
        "    \n",
        "    # Prepare the readout qubit.\n",
        "    circuit.append(cirq.X(readout))\n",
        "    circuit.append(cirq.H(readout))\n",
        "    \n",
        "    builder = CircuitLayerBuilder(\n",
        "        data_qubits = data_qubits,\n",
        "        readout=readout)\n",
        "\n",
        "    # Then add layers (experiment by adding more).\n",
        "    builder.add_layer(circuit, cirq.XX, \"xx1\")\n",
        "    builder.add_layer(circuit, cirq.ZZ, \"zz1\")\n",
        "\n",
        "    # Finally, prepare the readout qubit.\n",
        "    circuit.append(cirq.H(readout))\n",
        "\n",
        "    return circuit, cirq.Z(readout)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2QZvVh7vojhx"
      },
      "outputs": [],
      "source": [
        "model_circuit, model_readout = create_quantum_model()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LY7vbY6yfABE"
      },
      "source": [
        "### 2.2 tfq-keras モデルでモデル回路をラップする\n",
        "\n",
        "量子コンポーネントで Keras モデルを構築します。このモデルには、古典的なデータをエンコードする「量子データ」が `x_train_circ` からフィードされます。*パラメータ化された量子回路*レイヤーの `tfq.layers.PQC` を使用して、量子データでモデル回路をトレーニングするモデルです。\n",
        "\n",
        "<a href=\"https://arxiv.org/pdf/1802.06002.pdf\" class=\"external\">Farhi et al.</a> は、画像を分類するには、パラメータ化された回路で読み出しキュービットの期待値を使用することを提案しています。期待値は、1 から -1 の値です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZYdf_KOxojh0"
      },
      "outputs": [],
      "source": [
        "# Build the Keras model.\n",
        "model = tf.keras.Sequential([\n",
        "    # The input is the data-circuit, encoded as a tf.string\n",
        "    tf.keras.layers.Input(shape=(), dtype=tf.string),\n",
        "    # The PQC layer returns the expected value of the readout gate, range [-1,1].\n",
        "    tfq.layers.PQC(model_circuit, model_readout),\n",
        "])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jz-FbVc9ojh3"
      },
      "source": [
        "次に、`compile` メソッドを使用して、モデルにトレーニング手順を指定します。\n",
        "\n",
        "期待される読み出しは `[-1,1]` の範囲であるため、ヒンジ損失を最適化すると、ある程度自然な適合となります。\n",
        "\n",
        "注意: もう 1 つの有効なアプローチとして、出力範囲を `[0,1]` にシフトし、モデルがクラス `3` に割りてる確率として扱う方法があります。これは、標準的な`tf.losses.BinaryCrossentropy` 損失で使用することができます。\n",
        "\n",
        "ここでヒンジ損失を使用するには、小さな調整を 2 つ行う必要があります。まず、ラベル `y_train_nocon` をブール型からヒンジ損失が期待する `[-1,1]` に変換することです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CgMNkC1Fojh5"
      },
      "outputs": [],
      "source": [
        "y_train_hinge = 2.0*y_train_nocon-1.0\n",
        "y_test_hinge = 2.0*y_test-1.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5nwnveDiojh7"
      },
      "source": [
        "次に、`[-1, 1]` を `y_true` ラベル引数として正しく処理するカスタムの `hinge_accuracy` メトリックを使用します。`tf.losses.BinaryAccuracy(threshold=0.0)` は `y_true` がブール型であることを期待するため、ヒンジ損失とは使用できません。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3XKtZ_TEojh8"
      },
      "outputs": [],
      "source": [
        "def hinge_accuracy(y_true, y_pred):\n",
        "    y_true = tf.squeeze(y_true) > 0.0\n",
        "    y_pred = tf.squeeze(y_pred) > 0.0\n",
        "    result = tf.cast(y_true == y_pred, tf.float32)\n",
        "\n",
        "    return tf.reduce_mean(result)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FlpETlLRojiA"
      },
      "outputs": [],
      "source": [
        "model.compile(\n",
        "    loss=tf.keras.losses.Hinge(),\n",
        "    optimizer=tf.keras.optimizers.Adam(),\n",
        "    metrics=[hinge_accuracy])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jkHq2RstojiC"
      },
      "outputs": [],
      "source": [
        "print(model.summary())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lsuOzDYblA9s"
      },
      "source": [
        "### 量子モデルをトレーニングする\n",
        "\n",
        "では、モデルをトレーニングしましょう。これには約 45 分かかりますが、その時間を待てない方は、小規模なデータのサブセット（以下の`NUM_EXAMPLES=500` セット）を使用するとよいでしょう。トレーニング時のモデルの進捗にあまり影響はありません（パラメータは 32 しかなく、これらを制約する上であまりデータは必要ありません）。サンプル数を減らすことでトレーニングを早めに（5 分程度）終わらせることができますが、検証ログに進捗状況を示すには十分な長さです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n8vuQpSLlBV2"
      },
      "outputs": [],
      "source": [
        "EPOCHS = 3\n",
        "BATCH_SIZE = 32\n",
        "\n",
        "NUM_EXAMPLES = len(x_train_tfcirc)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qJnNG-3JojiI"
      },
      "outputs": [],
      "source": [
        "x_train_tfcirc_sub = x_train_tfcirc[:NUM_EXAMPLES]\n",
        "y_train_hinge_sub = y_train_hinge[:NUM_EXAMPLES]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QMSdgGC1GL7D"
      },
      "source": [
        "このモデルを収束までトレーニングすると、テストセットにおいて 85% を超える精度が達成されます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ya9qP3KkojiM"
      },
      "outputs": [],
      "source": [
        "qnn_history = model.fit(\n",
        "      x_train_tfcirc_sub, y_train_hinge_sub,\n",
        "      batch_size=32,\n",
        "      epochs=EPOCHS,\n",
        "      verbose=1,\n",
        "      validation_data=(x_test_tfcirc, y_test_hinge))\n",
        "\n",
        "qnn_results = model.evaluate(x_test_tfcirc, y_test)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3ER7B7aaojiP"
      },
      "source": [
        "注意: トレーニング精度はエポックの平均値を示します。検証精度はエポックの終了ごとに評価されます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8952YvuWGL7J"
      },
      "source": [
        "## 3. 従来のニューラルネットワーク\n",
        "\n",
        "量子ニューラルネットワークは、この単純化された MNIST 問題で機能するものの、このタスクでは、従来のニューラルネットワークの性能が QNN をはるかに上回ります。1 つのエポックが終了した時点で、従来のニューラルネットワークは縮小したセットで 98% を超える精度を達成することができます。\n",
        "\n",
        "次の例では、画像をサブサンプリングする代わりに 28x28 の画像を使用した 3 と 6 の分類問題に従来のニューラルネットワークを使用しています。これはほぼ 100% 精度のテストセットに難なく収束します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pZofEHhLGL7L"
      },
      "outputs": [],
      "source": [
        "def create_classical_model():\n",
        "    # A simple model based off LeNet from https://keras.io/examples/mnist_cnn/\n",
        "    model = tf.keras.Sequential()\n",
        "    model.add(tf.keras.layers.Conv2D(32, [3, 3], activation='relu', input_shape=(28,28,1)))\n",
        "    model.add(tf.keras.layers.Conv2D(64, [3, 3], activation='relu'))\n",
        "    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))\n",
        "    model.add(tf.keras.layers.Dropout(0.25))\n",
        "    model.add(tf.keras.layers.Flatten())\n",
        "    model.add(tf.keras.layers.Dense(128, activation='relu'))\n",
        "    model.add(tf.keras.layers.Dropout(0.5))\n",
        "    model.add(tf.keras.layers.Dense(1))\n",
        "    return model\n",
        "\n",
        "\n",
        "model = create_classical_model()\n",
        "model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "              optimizer=tf.keras.optimizers.Adam(),\n",
        "              metrics=['accuracy'])\n",
        "\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CiAJl7sZojiU"
      },
      "outputs": [],
      "source": [
        "model.fit(x_train,\n",
        "          y_train,\n",
        "          batch_size=128,\n",
        "          epochs=1,\n",
        "          verbose=1,\n",
        "          validation_data=(x_test, y_test))\n",
        "\n",
        "cnn_results = model.evaluate(x_test, y_test)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X5-5BVJaojiZ"
      },
      "source": [
        "上記のモデルには約 120 万個のパラメータがあります。より公正な比較を行うために、サブサンプリングした画像で 37 個のパラメータモデルを使用してみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "70TOM6r-ojiZ"
      },
      "outputs": [],
      "source": [
        "def create_fair_classical_model():\n",
        "    # A simple model based off LeNet from https://keras.io/examples/mnist_cnn/\n",
        "    model = tf.keras.Sequential()\n",
        "    model.add(tf.keras.layers.Flatten(input_shape=(4,4,1)))\n",
        "    model.add(tf.keras.layers.Dense(2, activation='relu'))\n",
        "    model.add(tf.keras.layers.Dense(1))\n",
        "    return model\n",
        "\n",
        "\n",
        "model = create_fair_classical_model()\n",
        "model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "              optimizer=tf.keras.optimizers.Adam(),\n",
        "              metrics=['accuracy'])\n",
        "\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lA_Fx-8gojid"
      },
      "outputs": [],
      "source": [
        "model.fit(x_train_bin,\n",
        "          y_train_nocon,\n",
        "          batch_size=128,\n",
        "          epochs=20,\n",
        "          verbose=2,\n",
        "          validation_data=(x_test_bin, y_test))\n",
        "\n",
        "fair_nn_results = model.evaluate(x_test_bin, y_test)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RH3mam7EGL7N"
      },
      "source": [
        "## 4. 比較\n",
        "\n",
        "解像度の高い入力とより強力なモデルの場合、CNN はこの問題を簡単に解決できますが、似たようなパワー（最大 32 個のパラメータ）を持つ古典的モデルはわずかな時間で似たような精度までトレーニングすることができます。いずれにせよ、従来のニューラルネットワークは量子ニューラルネットワークの性能を簡単に上回ります。古典的なデータでは、従来のニューラルネットワークを上回るのは困難といえます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NOMeN7pMGL7P"
      },
      "outputs": [],
      "source": [
        "qnn_accuracy = qnn_results[1]\n",
        "cnn_accuracy = cnn_results[1]\n",
        "fair_nn_accuracy = fair_nn_results[1]\n",
        "\n",
        "sns.barplot([\"Quantum\", \"Classical, full\", \"Classical, fair\"],\n",
        "            [qnn_accuracy, cnn_accuracy, fair_nn_accuracy])"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "mnist.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
